[Speculative Decoding] Refactor EAGLE3 training to YAML-based config and recipe system#1134
[Speculative Decoding] Refactor EAGLE3 training to YAML-based config and recipe system#1134
Conversation
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
|
Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually. Contributors can view more details about this message here. |
|
Note Reviews pausedIt looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the Use the following commands to manage reviews:
Use the checkboxes below for quick actions:
📝 WalkthroughWalkthroughRefactors speculative-decoding examples to use OmegaConf/YAML configs: scripts and tests now produce or reference a base YAML plus dotlist overrides; Changes
Sequence Diagram(s)sequenceDiagram
participant User
participant Launch as launch_train.sh
participant Config as YAML (_base_eagle3 + dotlist)
participant Main as main.py
participant Convert as mtsp.convert
participant Trainer as transformers.Trainer
User->>Launch: run --config <yaml> + overrides
Launch->>Config: reference base YAML + pass dotlist overrides
Launch->>Main: accelerate launch main.py --config <yaml> <overrides>
Main->>Config: OmegaConf.load/merge via _load_config()
Main->>Main: flatten into model/data/training dicts + extract eagle dict
Main->>Convert: mtsp.convert(model, [("eagle", eagle_cfg)])
Main->>Trainer: Trainer.train()
Trainer-->>Main: training complete
Main-->>User: save_model / export artifacts
Estimated code review effort🎯 4 (Complex) | ⏱️ ~45 minutes 🚥 Pre-merge checks | ✅ 4✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Comment |
|
Codecov Report✅ All modified and coverable lines are covered by tests.
Additional details and impacted files@@ Coverage Diff @@
## main #1134 +/- ##
===========================================
- Coverage 70.19% 54.52% -15.68%
===========================================
Files 230 348 +118
Lines 26044 39778 +13734
===========================================
+ Hits 18281 21688 +3407
- Misses 7763 18090 +10327
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
There was a problem hiding this comment.
Actionable comments posted: 8
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
examples/speculative_decoding/main.py (1)
258-259:⚠️ Potential issue | 🔴 CriticalAdd
weights_only=Truetotorch.load()call for security.The
torch.load(data_args.draft_vocab_cache)at line 258 does not specifyweights_only=True, which allows arbitrary code execution from malicious pickle files. Sinced2tis a pure tensor (int64),weights_only=Trueis both safe and compatible.Proposed fix
- model.eagle_module.d2t = torch.load(data_args.draft_vocab_cache) + model.eagle_module.d2t = torch.load(data_args.draft_vocab_cache, weights_only=True)🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/speculative_decoding/main.py` around lines 258 - 259, The torch.load call that assigns model.eagle_module.d2t from data_args.draft_vocab_cache should pass weights_only=True to avoid executing pickled code; update the load call in the code that sets model.eagle_module.d2t to use torch.load(data_args.draft_vocab_cache, weights_only=True) so only tensor data is deserialized.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@examples/speculative_decoding/launch_train.sh`:
- Around line 49-79: The accelerate launch invocation currently interpolates
unquoted variables into sh -c and omits --num_processes on single-node runs; to
fix, build the command as a bash array (e.g., CMD=()) instead of a single sh -c
string, append required flags using the existing
MULTI_NODE_ARGS/MODEL_ARG/TOTAL_GPU symbols (ensure MULTI_NODE_ARGS always
includes "--num_processes $TOTAL_GPU" even for single-node), and then run the
launch with "${CMD[@]}" so that $CONFIG_FILE, $MODEL, $HEAD_NODE_IP and other
variables are safely quoted and preserved without word-splitting or accidental
expansion.
In `@examples/speculative_decoding/main.py`:
- Line 111: The metadata help string for the dataclass field ar_validate_steps
is incomplete; update the metadata["help"] for ar_validate_steps to a full,
descriptive sentence (e.g., "Number of autoregressive validation steps to run
during evaluation" or similar) so users understand its purpose; locate the
ar_validate_steps field definition and replace the truncated help text with the
completed description.
In `@examples/speculative_decoding/train_eagle3_and_export.sh`:
- Around line 43-48: train_config.yaml is missing the base model identifier so
the generated YAML is not replayable; update the code that writes YAML_FILE
(train_config.yaml) to include the model_name_or_path value (the model used via
the --model override) under model: (e.g., model_name_or_path: "<value>") so the
config fully captures the runtime model selection; ensure the string comes from
the same variable/arg used to parse the --model override and is written when
creating YAML_FILE (preserving YAML_FILE, OUTPUT_DIR, and model_name_or_path
references).
In `@modelopt_recipes/speculative_decoding/kimi_k25_eagle_offline.yaml`:
- Around line 3-6: The recipe currently enables trust_remote_code by default in
the model block (fields model_name_or_path: moonshotai/Kimi-K2.5 and
trust_remote_code: true); change that default to false and instead
document/require an explicit opt-in (e.g., a commented flag or
environment-driven toggle) so users must consciously enable trust_remote_code
for the Kimi recipe; update any README or inline comment near the model
configuration and/or the use_fake_base_for_offline handling so it explains how
to opt in (enable trust_remote_code) when the user intentionally trusts the
model's custom HF code.
In `@modelopt_recipes/speculative_decoding/llama3_eagle_offline.yaml`:
- Around line 4-6: Replace the unsafe default by changing the YAML key
trust_remote_code from true to false in the model block of the Llama offline
recipe (the block containing model_name_or_path: meta-llama/Llama-3.2-1B);
update the value so the recipe does not silently enable remote code execution
and leave a brief comment if you want to document that users must opt-in to
enable remote code loading manually.
In `@modelopt_recipes/speculative_decoding/llama3_eagle_online.yaml`:
- Around line 4-6: The YAML enables trust_remote_code for a stock Llama model;
remove or set trust_remote_code to false to avoid executing arbitrary repo code.
Edit the model block that contains model_name_or_path: meta-llama/Llama-3.2-1B
and either delete the trust_remote_code line or change it to trust_remote_code:
false so the pipeline uses the standard transformers implementation rather than
allowing remote code execution.
In `@tests/examples/speculative_decoding/test_eagle.py`:
- Around line 138-149: The test currently writes both mix_hidden_states variants
into the same output directory causing runs to clobber each other; modify the
training output_dir construction (where eagle_output_dir /
f"eagle-tinyllama-cp{cp_size}" is used) to include the mix_hidden_states flag
(e.g., append `_mix{mix_hidden_states}` or similar) so each (cp_size,
mix_hidden_states) combination gets a unique checkpoint directory; update any
references that assume the old path (e.g., test_resume_training) to use the new
per-variant output_dir variable.
- Around line 269-273: Parametrize the trust_remote_code flag instead of
hardcoding True: add a test parameter (default False) named trust_remote_code to
the relevant test cases and use it when writing the model YAML dictionary
(replace the hardcoded "trust_remote_code": True with "trust_remote_code":
trust_remote_code) and when calling AutoConfig.from_pretrained (replace the
hardcoded trust_remote_code=True with trust_remote_code=trust_remote_code);
update only the specific test invocations that require remote code execution to
pass trust_remote_code=True. Ensure the new parameter is included in the pytest
parametrization for the test function(s) that build the YAML/model config so
local models keep trust_remote_code=False while remote-model cases explicitly
set it to True.
---
Outside diff comments:
In `@examples/speculative_decoding/main.py`:
- Around line 258-259: The torch.load call that assigns model.eagle_module.d2t
from data_args.draft_vocab_cache should pass weights_only=True to avoid
executing pickled code; update the load call in the code that sets
model.eagle_module.d2t to use torch.load(data_args.draft_vocab_cache,
weights_only=True) so only tensor data is deserialized.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: c971e970-8ce6-4555-9bd5-f56f417bbb15
📒 Files selected for processing (11)
examples/speculative_decoding/README.mdexamples/speculative_decoding/eagle_config.jsonexamples/speculative_decoding/fsdp_config.jsonexamples/speculative_decoding/launch_train.shexamples/speculative_decoding/main.pyexamples/speculative_decoding/train_eagle3_and_export.shmodelopt_recipes/speculative_decoding/_base_eagle3.yamlmodelopt_recipes/speculative_decoding/kimi_k25_eagle_offline.yamlmodelopt_recipes/speculative_decoding/llama3_eagle_offline.yamlmodelopt_recipes/speculative_decoding/llama3_eagle_online.yamltests/examples/speculative_decoding/test_eagle.py
💤 Files with no reviewable changes (2)
- examples/speculative_decoding/eagle_config.json
- examples/speculative_decoding/fsdp_config.json
modelopt_recipes/speculative_decoding/kimi_k25_eagle_offline.yaml
Outdated
Show resolved
Hide resolved
modelopt_recipes/speculative_decoding/llama3_eagle_offline.yaml
Outdated
Show resolved
Hide resolved
|
So does the yaml file encode all the information modelopt needs for the eagle3 training? |
Basically yes. The only exception is the accelerate configs (e.g. multinode settings). They need to be passed in addition to the yaml config, e.g.: I think they are orthogonal to the "recipe" and is more convenient to set in this way, since the node ip is often dynamic on slurm jobs. Do you think it's better to put it also in the yaml? |
ChenhanYu
left a comment
There was a problem hiding this comment.
PR Review: Refactor EAGLE3 training to YAML-based config and recipe system
Summary
Clean refactor that replaces ~250 lines of shell argument parsing in launch_train.sh with a YAML-based config system using OmegaConf. Config files support __base__ inheritance, and pre-configured recipes are shipped under modelopt_recipes/speculative_decoding/. The EagleArguments dataclass is removed — eagle config now passes directly from YAML to mtsp.convert(). Tests and README are updated accordingly. Net reduction: -74 lines. The direction is good.
Findings
1. Missing omegaconf dependency — Blocker
examples/speculative_decoding/main.py:47 — from omegaconf import OmegaConf
This is a new import, but omegaconf is not added to pyproject.toml extras or any requirements.txt. Users will get ImportError unless they happen to have it installed transitively (e.g., via Hydra). Needs to be added as a dependency.
2. trust_remote_code: true on Llama recipes — Security
modelopt_recipes/speculative_decoding/llama3_eagle_online.yaml:7 and llama3_eagle_offline.yaml:7
Llama models don't require trust_remote_code. Given that PR #975 just put effort into removing hardcoded trust_remote_code=True throughout the codebase, shipping recipes with it enabled by default undermines that security improvement. Should be false for Llama recipes.
3. _parse_cli silently ignores unknown args — Migration footgun
examples/speculative_decoding/main.py:133 — args, _ = p.parse_known_args()
Users migrating from the old CLI (e.g., --eagle_config, --mode eagle3, --mix_hidden_states) will have their flags silently ignored with no error or deprecation warning. Consider logging the unknown args.
4. dp_shard_size: 0 magic sentinel — Edge case
examples/speculative_decoding/main.py:168-170 — If torch.cuda.device_count() returns 0 (CPU-only node, CUDA not visible), this produces 0. Guard with gpu_count = torch.cuda.device_count() or 1.
5. train_eagle3_and_export.sh YAML not self-contained
Line 103: ./launch_train.sh --config "$YAML_FILE" --model "$BASE_MODEL" — The generated YAML doesn't include model_name_or_path, so the config alone can't reproduce the training run despite the comment saying it's "preserved alongside the checkpoint."
6. Truncated help text
examples/speculative_decoding/main.py:111 — ar_validate_steps help string is "AR validation ." — incomplete.
7. Flat config merge can silently collide
main.py:159-163 merges model, data, and training dicts into one flat dict. If any sections share a key name, the later section silently wins.
Overall this is a well-structured refactor. Main action items: add omegaconf dependency, fix trust_remote_code on Llama recipes, and consider logging unknown CLI args for migration safety.
This is an AI-assisted review — human sign-off required before merging.
There was a problem hiding this comment.
Why don't we just move this file to tools/launcher/examples/moonshotai/Kimi-K2.5/?
There was a problem hiding this comment.
We can. It's not clear to me whether modelopt_recipes/ or tool/launcher/examples/ is the best place for these yamls. Curious what you think
There was a problem hiding this comment.
I think the quantization recipes in modelopt_recipes/ currently is in the right direction but not the ultimate format. It is used as one of the input to hf_ptq.py. This PR has encapsulated all arguments to main.py. The cleanest way is to provide it directly has a launcher yaml I think.
There was a problem hiding this comment.
Same. Why don't we move this to tools/launcher/examples/meta-llama/Llama-3.2-1B-Instruct?
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Comments 2, 3, 6 addressed. Other points seems fine to me |
There was a problem hiding this comment.
Actionable comments posted: 1
♻️ Duplicate comments (2)
tests/examples/speculative_decoding/test_eagle.py (2)
246-273:⚠️ Potential issue | 🔴 CriticalRemove hardcoded
trust_remote_code=True; make it explicit and default-safe.Hardcoding
trust_remote_code=Truein bothAutoConfig.from_pretrained(...)and generated YAML model configs is a CRITICAL security violation in this repo’s rules. Parameterize it and default toFalse, enablingTrueonly in explicitly justified test cases.Proposed direction
-@pytest.mark.parametrize( - ("model_source", "use_fake_base"), +@pytest.mark.parametrize( + ("model_source", "use_fake_base", "trust_remote_code"), [ - (None, False), - ("moonshotai/Kimi-K2.5", True), - ("moonshotai/Kimi-K2-Thinking", True), - ("MiniMaxAI/MiniMax-M2.5", True), + (None, False, False), + ("moonshotai/Kimi-K2.5", True, True), + ("moonshotai/Kimi-K2-Thinking", True, True), + ("MiniMaxAI/MiniMax-M2.5", True, True), ], ) def test_offline_eagle3_training(..., model_source, use_fake_base, trust_remote_code): ... - cfg = transformers.AutoConfig.from_pretrained(model_path, trust_remote_code=True) + cfg = transformers.AutoConfig.from_pretrained( + model_path, trust_remote_code=trust_remote_code + ) ... - "trust_remote_code": True, + "trust_remote_code": trust_remote_code,Apply the same pattern to
test_offline_resume_training_kimiinstead of hardcodingTrue.As per coding guidelines,
Flag trust_remote_code=True hardcoded for transformers model or tokenizer loading as CRITICAL security issue. Code should expose it as a caller-configurable parameter defaulting to False, not hardcode True.Also applies to: 306-320
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/examples/speculative_decoding/test_eagle.py` around lines 246 - 273, Replace hardcoded trust_remote_code=True usages by adding a test-level parameter (default False) and passing that variable to transformers.AutoConfig.from_pretrained and into the generated training_cfg["model"]["trust_remote_code"]; specifically, introduce a local variable (e.g., trust_remote_code=False) at the top of the test and use it when calling AutoConfig.from_pretrained(...) and when building training_cfg["model"] instead of the literal True, and update any related test variants (e.g., test_offline_resume_training_kimi) to set trust_remote_code=True only when explicitly required.
139-139:⚠️ Potential issue | 🟠 MajorUpdate downstream consumers to the new
-mix...checkpoint path.Line 139 changed training output to
...-cp{cp_size}-mix{mix_hidden_states}, but later tests still readeagle-tinyllama-cp1(Line 187, Line 201), which can breaktest_ar_validateandtest_export_hf_checkpoint.Proposed fix
- "--model_path", eagle_output_dir / "eagle-tinyllama-cp1", + "--model_path", eagle_output_dir / "eagle-tinyllama-cp1-mixFalse", ... - "--model_path", eagle_output_dir / "eagle-tinyllama-cp1", + "--model_path", eagle_output_dir / "eagle-tinyllama-cp1-mixFalse",As per coding guidelines,
All test coverage checks in PRs must pass for new features and examples.🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/examples/speculative_decoding/test_eagle.py` at line 139, Tests still reference the old checkpoint path name; update downstream consumers to use the new output_dir format that includes the mix suffix (constructed via eagle_output_dir / f"eagle-tinyllama-cp{cp_size}-mix{mix_hidden_states}"). Locate usages in tests that read the checkpoint (notably the test functions test_ar_validate and test_export_hf_checkpoint) and replace hardcoded "eagle-tinyllama-cp1" (or similar) with the same formatted path logic or a derived variable from eagle_output_dir and cp_size/mix_hidden_states so both creation and consumption use the identical "-cp{cp_size}-mix{mix_hidden_states}" filename.
🧹 Nitpick comments (1)
examples/speculative_decoding/main.py (1)
131-134: Consider failing on unknown CLI args instead of silently ignoring them.Line 131-Line 134 currently accepts typos/legacy flags and continues. This can hide misconfiguration in the YAML migration path.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/speculative_decoding/main.py` around lines 131 - 134, The CLI currently uses p.parse_known_args() which swallows typos/legacy flags; change to p.parse_args() or explicitly fail when unknown is non-empty: after calling parse_known_args(), if unknown is truthy call parser.error(...) or raise SystemExit with a clear message so the script fails fast instead of printing via print_rank_0 and proceeding; update the handling around parse_known_args()/parse_args() and remove the print_rank_0 fallback so callers relying on args.config and args.model get validated input.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@examples/speculative_decoding/main.py`:
- Around line 256-260: The code currently calls
os.path.isfile(data_args.draft_vocab_cache) which raises a TypeError when
data_args.draft_vocab_cache is None; add a guard to check for a truthy/non-None
draft_vocab_cache before performing os.path.isfile. Specifically, in the block
that assigns model.eagle_module.d2t via torch.load, first assert or raise a
clear error if data_args.draft_vocab_cache is falsy (e.g., None or empty string)
with a descriptive message, then proceed to call
os.path.isfile(data_args.draft_vocab_cache) and
torch.load(data_args.draft_vocab_cache, weights_only=True) only when the path
exists.
---
Duplicate comments:
In `@tests/examples/speculative_decoding/test_eagle.py`:
- Around line 246-273: Replace hardcoded trust_remote_code=True usages by adding
a test-level parameter (default False) and passing that variable to
transformers.AutoConfig.from_pretrained and into the generated
training_cfg["model"]["trust_remote_code"]; specifically, introduce a local
variable (e.g., trust_remote_code=False) at the top of the test and use it when
calling AutoConfig.from_pretrained(...) and when building training_cfg["model"]
instead of the literal True, and update any related test variants (e.g.,
test_offline_resume_training_kimi) to set trust_remote_code=True only when
explicitly required.
- Line 139: Tests still reference the old checkpoint path name; update
downstream consumers to use the new output_dir format that includes the mix
suffix (constructed via eagle_output_dir /
f"eagle-tinyllama-cp{cp_size}-mix{mix_hidden_states}"). Locate usages in tests
that read the checkpoint (notably the test functions test_ar_validate and
test_export_hf_checkpoint) and replace hardcoded "eagle-tinyllama-cp1" (or
similar) with the same formatted path logic or a derived variable from
eagle_output_dir and cp_size/mix_hidden_states so both creation and consumption
use the identical "-cp{cp_size}-mix{mix_hidden_states}" filename.
---
Nitpick comments:
In `@examples/speculative_decoding/main.py`:
- Around line 131-134: The CLI currently uses p.parse_known_args() which
swallows typos/legacy flags; change to p.parse_args() or explicitly fail when
unknown is non-empty: after calling parse_known_args(), if unknown is truthy
call parser.error(...) or raise SystemExit with a clear message so the script
fails fast instead of printing via print_rank_0 and proceeding; update the
handling around parse_known_args()/parse_args() and remove the print_rank_0
fallback so callers relying on args.config and args.model get validated input.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: e99433a4-e316-439e-8872-ae66d1eeea96
📒 Files selected for processing (4)
examples/speculative_decoding/main.pymodelopt_recipes/speculative_decoding/llama3_eagle_offline.yamlmodelopt_recipes/speculative_decoding/llama3_eagle_online.yamltests/examples/speculative_decoding/test_eagle.py
✅ Files skipped from review due to trivial changes (1)
- modelopt_recipes/speculative_decoding/llama3_eagle_online.yaml
🚧 Files skipped from review as they are similar to previous changes (1)
- modelopt_recipes/speculative_decoding/llama3_eagle_offline.yaml
| if not os.path.isfile(data_args.draft_vocab_cache): | ||
| raise FileNotFoundError( | ||
| f"Draft vocab cache provided but not found: {data_args.draft_vocab_cache}" | ||
| ) | ||
| model.eagle_module.d2t = torch.load(data_args.draft_vocab_cache) | ||
| model.eagle_module.d2t = torch.load(data_args.draft_vocab_cache, weights_only=True) |
There was a problem hiding this comment.
Guard draft_vocab_cache before filesystem checks to avoid TypeError.
On Line 256, os.path.isfile(data_args.draft_vocab_cache) will throw when draft_vocab_cache is None. Fail fast with a clear message before checking file existence.
Proposed fix
# Load draft vocab cache if the draft model uses a compressed vocabulary
if model.eagle_config.draft_vocab_size < model.eagle_config.vocab_size:
- if not os.path.isfile(data_args.draft_vocab_cache):
+ if not data_args.draft_vocab_cache:
+ raise ValueError(
+ "data.draft_vocab_cache must be set when draft_vocab_size < vocab_size."
+ )
+ if not os.path.isfile(data_args.draft_vocab_cache):
raise FileNotFoundError(
f"Draft vocab cache provided but not found: {data_args.draft_vocab_cache}"
)
model.eagle_module.d2t = torch.load(data_args.draft_vocab_cache, weights_only=True)🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/speculative_decoding/main.py` around lines 256 - 260, The code
currently calls os.path.isfile(data_args.draft_vocab_cache) which raises a
TypeError when data_args.draft_vocab_cache is None; add a guard to check for a
truthy/non-None draft_vocab_cache before performing os.path.isfile.
Specifically, in the block that assigns model.eagle_module.d2t via torch.load,
first assert or raise a clear error if data_args.draft_vocab_cache is falsy
(e.g., None or empty string) with a descriptive message, then proceed to call
os.path.isfile(data_args.draft_vocab_cache) and
torch.load(data_args.draft_vocab_cache, weights_only=True) only when the path
exists.
| trust_remote_code: true | ||
|
|
||
| data: | ||
| offline_data_path: <path to offline data> |
There was a problem hiding this comment.
how do we get the data prepared?
There was a problem hiding this comment.
All of these need to be provided explicitly in the launcher.
modelopt_recipes/speculative_decoding/kimi_k25_eagle_offline.yaml
Outdated
Show resolved
Hide resolved
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
There was a problem hiding this comment.
Actionable comments posted: 3
♻️ Duplicate comments (2)
examples/speculative_decoding/launch_train.sh (1)
59-74:⚠️ Potential issue | 🔴 CriticalBuild the accelerate command as an array instead of
sh -c.Interpolating
$CONFIG_FILE,$HEAD_NODE_IP, and${EXTRA_ARGS[*]}into one shell string reintroduces word-splitting/injection, and the single-node path computesTOTAL_GPUwithout ever using it in the final command.Suggested fix
-# Multi-node routing args (accelerate only; training config comes from the YAML) -MULTI_NODE_ARGS="" -if [[ "$NUM_NODES" != "1" ]]; then - MULTI_NODE_ARGS="--num_processes $TOTAL_GPU \ - --num_machines $NUM_NODES \ - --machine_rank $SLURM_PROCID \ - --rdzv_backend c10d \ - --main_process_ip $HEAD_NODE_IP \ - --main_process_port 29500" -fi +# Build the launch command as an array so overrides stay quoted. +LAUNCH_ARGS=(accelerate launch --mixed_precision bf16 --num_processes "$TOTAL_GPU") +if [[ "$NUM_NODES" != "1" ]]; then + LAUNCH_ARGS+=( + --num_machines "$NUM_NODES" + --machine_rank "$SLURM_PROCID" + --rdzv_backend c10d + --main_process_ip "$HEAD_NODE_IP" + --main_process_port 29500 + ) +fi +LAUNCH_ARGS+=("${SCRIPT_DIR}/main.py" --config "$CONFIG_FILE" "${EXTRA_ARGS[@]}") @@ -sh -c "accelerate launch --mixed_precision bf16 $MULTI_NODE_ARGS ${SCRIPT_DIR}/main.py --config $CONFIG_FILE ${EXTRA_ARGS[*]}" +"${LAUNCH_ARGS[@]}"#!/bin/bash # Confirm that the current implementation still uses `sh -c` # and only wires `--num_processes` through the multi-node string. rg -n -C2 'TOTAL_GPU|MULTI_NODE_ARGS|num_processes|sh -c' examples/speculative_decoding/launch_train.sh🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/speculative_decoding/launch_train.sh` around lines 59 - 74, The accelerate launch command should be built as a shell array and executed directly instead of using sh -c with a single interpolated string to avoid word-splitting/injection; update the script to construct CMD as an array including "accelerate", "launch", "--mixed_precision", "bf16", plus the expanded MULTI_NODE_ARGS elements (only when NUM_NODES != 1), then append "${SCRIPT_DIR}/main.py", "--config", "$CONFIG_FILE", and each element from EXTRA_ARGS, and finally run it with exec "${CMD[@]}"; also remove or use the computed TOTAL_GPU (now unused) so you don’t compute --num_processes only in MULTI_NODE_ARGS without applying it in the single-node case.examples/speculative_decoding/main.py (1)
246-252:⚠️ Potential issue | 🟠 MajorGuard
draft_vocab_cachebefore the filesystem check.When compressed vocab is enabled and
data.draft_vocab_cacheis unset, this path raisesTypeErrorfromos.path.isfile()instead of a clear config error.Suggested fix
# Load draft vocab cache if the draft model uses a compressed vocabulary if model.eagle_config.draft_vocab_size < model.eagle_config.vocab_size: + if not data_args.draft_vocab_cache: + raise ValueError( + "data.draft_vocab_cache must be set when draft_vocab_size < vocab_size." + ) if not os.path.isfile(data_args.draft_vocab_cache): raise FileNotFoundError( f"Draft vocab cache provided but not found: {data_args.draft_vocab_cache}" )🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/speculative_decoding/main.py` around lines 246 - 252, The filesystem check for draft_vocab_cache calls os.path.isfile(data_args.draft_vocab_cache) without ensuring data_args.draft_vocab_cache is set, causing a TypeError; update the block that handles compressed vocab (references: model.eagle_config.draft_vocab_size, model.eagle_config.vocab_size, model.eagle_module.d2t, data_args.draft_vocab_cache) to first verify data_args.draft_vocab_cache is truthy (not None/empty) and if not raise a clear config error (e.g., ValueError) indicating the draft_vocab_cache must be provided; only then call os.path.isfile(...) and load the tensor with torch.load when the file exists.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@examples/speculative_decoding/launch_train.sh`:
- Around line 30-39: Add parsing for the documented --model override so it
doesn't fall into EXTRA_ARGS: add a case branch for --model* (similar to
--config* etc.) that consumes either --model=... or the next arg and stores it
in a variable (e.g., MODEL_PATH) instead of EXTRA_ARGS; then ensure the launcher
uses MODEL_PATH to set model.model_name_or_path before launch. Update the case
block (near CONFIG_FILE, NUM_NODES, HEAD_NODE_IP, EXTRA_ARGS) to handle --model*
and ensure the code path that configures the model reads MODEL_PATH and assigns
it to model.model_name_or_path.
In `@examples/speculative_decoding/main.py`:
- Around line 136-149: The function _load_config currently only loads the target
file into merged and then flattens it, which drops any inherited defaults
referenced via a __base__ recipe; fix this by resolving __base__ entries before
converting to a plain dict: recursively detect __base__ in merged (string or
list), load each base YAML with OmegaConf.load (resolving relative paths against
the target config path), merge bases into a combined base_conf using
OmegaConf.merge, then merge the child merged onto that base_conf, repeating
until no __base__ remains; only after fully resolving inheritance apply the
dotlist overrides with OmegaConf.from_dotlist and finally call
OmegaConf.to_container to produce cfg and return the eagle_cfg as before (update
references in the code around merged, OmegaConf.load, OmegaConf.merge,
OmegaConf.from_dotlist, and cfg).
In `@examples/speculative_decoding/README.md`:
- Line 261: The README uses the wrong top-level key for the draft vocab size;
update the YAML example to the correct nested key
eagle.eagle_architecture_config.draft_vocab_size (keep data.draft_vocab_cache as
shown) and search/update any other documentation or examples that reference
eagle_architecture_config.draft_vocab_size so they use
eagle.eagle_architecture_config.draft_vocab_size to match the runtime schema.
---
Duplicate comments:
In `@examples/speculative_decoding/launch_train.sh`:
- Around line 59-74: The accelerate launch command should be built as a shell
array and executed directly instead of using sh -c with a single interpolated
string to avoid word-splitting/injection; update the script to construct CMD as
an array including "accelerate", "launch", "--mixed_precision", "bf16", plus the
expanded MULTI_NODE_ARGS elements (only when NUM_NODES != 1), then append
"${SCRIPT_DIR}/main.py", "--config", "$CONFIG_FILE", and each element from
EXTRA_ARGS, and finally run it with exec "${CMD[@]}"; also remove or use the
computed TOTAL_GPU (now unused) so you don’t compute --num_processes only in
MULTI_NODE_ARGS without applying it in the single-node case.
In `@examples/speculative_decoding/main.py`:
- Around line 246-252: The filesystem check for draft_vocab_cache calls
os.path.isfile(data_args.draft_vocab_cache) without ensuring
data_args.draft_vocab_cache is set, causing a TypeError; update the block that
handles compressed vocab (references: model.eagle_config.draft_vocab_size,
model.eagle_config.vocab_size, model.eagle_module.d2t,
data_args.draft_vocab_cache) to first verify data_args.draft_vocab_cache is
truthy (not None/empty) and if not raise a clear config error (e.g., ValueError)
indicating the draft_vocab_cache must be provided; only then call
os.path.isfile(...) and load the tensor with torch.load when the file exists.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 97c4d6d2-a3e2-4cb8-821a-843118b66a02
📒 Files selected for processing (5)
examples/speculative_decoding/README.mdexamples/speculative_decoding/launch_train.shexamples/speculative_decoding/main.pyexamples/speculative_decoding/train_eagle3_and_export.shmodelopt_recipes/speculative_decoding/_base_eagle3.yaml
✅ Files skipped from review due to trivial changes (1)
- modelopt_recipes/speculative_decoding/_base_eagle3.yaml
🚧 Files skipped from review as they are similar to previous changes (1)
- examples/speculative_decoding/train_eagle3_and_export.sh
| This will produce a `d2t.pt` file in `save_dir`, which is the mapping from draft token to target token. During inference, draft tokens can be mapped back to target tokens by `target_token = draft_token + d2t[draft_token]`. | ||
|
|
||
| Then, simply set `{"draft_vocab_size":32000}` in `eagle_config.json` and include `--draft_vocab_cache <path_to_d2t.pt>` when running `./launch_train.sh`. The draft model will use this provided vocab table during training and export. | ||
| Then, set `eagle_architecture_config.draft_vocab_size: 32000` and `data.draft_vocab_cache: <path_to_d2t.pt>` in your YAML. The draft model will use this provided vocab table during training and export. |
There was a problem hiding this comment.
Use the full nested YAML path for draft_vocab_size.
The runtime schema nests this under eagle, so eagle_architecture_config.draft_vocab_size reads like a top-level key. Use eagle.eagle_architecture_config.draft_vocab_size to match the actual config structure.
Suggested fix
-Then, set `eagle_architecture_config.draft_vocab_size: 32000` and `data.draft_vocab_cache: <path_to_d2t.pt>` in your YAML. The draft model will use this provided vocab table during training and export.
+Then, set `eagle.eagle_architecture_config.draft_vocab_size: 32000` and `data.draft_vocab_cache: <path_to_d2t.pt>` in your YAML. The draft model will use this provided vocab table during training and export.📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| Then, set `eagle_architecture_config.draft_vocab_size: 32000` and `data.draft_vocab_cache: <path_to_d2t.pt>` in your YAML. The draft model will use this provided vocab table during training and export. | |
| Then, set `eagle.eagle_architecture_config.draft_vocab_size: 32000` and `data.draft_vocab_cache: <path_to_d2t.pt>` in your YAML. The draft model will use this provided vocab table during training and export. |
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/speculative_decoding/README.md` at line 261, The README uses the
wrong top-level key for the draft vocab size; update the YAML example to the
correct nested key eagle.eagle_architecture_config.draft_vocab_size (keep
data.draft_vocab_cache as shown) and search/update any other documentation or
examples that reference eagle_architecture_config.draft_vocab_size so they use
eagle.eagle_architecture_config.draft_vocab_size to match the runtime schema.
There was a problem hiding this comment.
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
examples/speculative_decoding/train_eagle3_and_export.sh (1)
53-58:⚠️ Potential issue | 🟡 MinorQuote variable expansions to handle paths with spaces.
$OUTPUT_DIRand$EXPORT_PATHare unquoted, which will cause word-splitting issues ifBASE_MODEL(a user-supplied path) contains spaces.🔧 Suggested fix
-python scripts/ar_validate.py --model_path $OUTPUT_DIR +python scripts/ar_validate.py --model_path "$OUTPUT_DIR" ... -python scripts/export_hf_checkpoint.py --model_path $OUTPUT_DIR --export_path $EXPORT_PATH +python scripts/export_hf_checkpoint.py --model_path "$OUTPUT_DIR" --export_path "$EXPORT_PATH"🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/speculative_decoding/train_eagle3_and_export.sh` around lines 53 - 58, The script fails to quote variable expansions which breaks on paths with spaces: wrap expansions in double quotes when assigning and passing variables — set EXPORT_PATH="export/${MODEL_BASENAME}-$(date +%Y%m%d_%H%M)" and pass "$OUTPUT_DIR" to scripts/ar_validate.py and "$OUTPUT_DIR" and "$EXPORT_PATH" to scripts/export_hf_checkpoint.py (also quote any echo or other uses of these vars) so that OUTPUT_DIR, EXPORT_PATH, and MODEL_BASENAME are safely handled.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@examples/speculative_decoding/train_eagle3_and_export.sh`:
- Line 36: The assignment to BASE_CFG uses GNU-only readlink -f which breaks on
macOS; replace the readlink usage in the BASE_CFG assignment by computing the
script directory portably (e.g., use cd and pwd -P) so it works on BSD/macOS and
Linux: change the expression that currently uses readlink -f "$0" to a portable
form such as cd "$(dirname "$0")" && pwd -P (then append the relative path) so
the BASE_CFG variable is constructed reliably across platforms.
---
Outside diff comments:
In `@examples/speculative_decoding/train_eagle3_and_export.sh`:
- Around line 53-58: The script fails to quote variable expansions which breaks
on paths with spaces: wrap expansions in double quotes when assigning and
passing variables — set EXPORT_PATH="export/${MODEL_BASENAME}-$(date
+%Y%m%d_%H%M)" and pass "$OUTPUT_DIR" to scripts/ar_validate.py and
"$OUTPUT_DIR" and "$EXPORT_PATH" to scripts/export_hf_checkpoint.py (also quote
any echo or other uses of these vars) so that OUTPUT_DIR, EXPORT_PATH, and
MODEL_BASENAME are safely handled.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 0159f2c5-bc33-4d48-a892-624c9e0ab353
📒 Files selected for processing (2)
examples/speculative_decoding/README.mdexamples/speculative_decoding/train_eagle3_and_export.sh
🚧 Files skipped from review as they are similar to previous changes (1)
- examples/speculative_decoding/README.md
| OUTPUT_DIR=ckpts/${MODEL_BASENAME}-$(date +%Y%m%d_%H%M) | ||
| mkdir -p "$OUTPUT_DIR" | ||
|
|
||
| BASE_CFG="$(dirname "$(readlink -f "$0")")/../../modelopt_recipes/speculative_decoding/_base_eagle3.yaml" |
There was a problem hiding this comment.
readlink -f is not portable to macOS.
The -f (canonicalize) flag is a GNU extension and will fail on macOS where readlink lacks this option.
🔧 Portable alternative
-BASE_CFG="$(dirname "$(readlink -f "$0")")/../../modelopt_recipes/speculative_decoding/_base_eagle3.yaml"
+SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd -P)"
+BASE_CFG="$SCRIPT_DIR/../../modelopt_recipes/speculative_decoding/_base_eagle3.yaml"📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| BASE_CFG="$(dirname "$(readlink -f "$0")")/../../modelopt_recipes/speculative_decoding/_base_eagle3.yaml" | |
| SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd -P)" | |
| BASE_CFG="$SCRIPT_DIR/../../modelopt_recipes/speculative_decoding/_base_eagle3.yaml" |
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/speculative_decoding/train_eagle3_and_export.sh` at line 36, The
assignment to BASE_CFG uses GNU-only readlink -f which breaks on macOS; replace
the readlink usage in the BASE_CFG assignment by computing the script directory
portably (e.g., use cd and pwd -P) so it works on BSD/macOS and Linux: change
the expression that currently uses readlink -f "$0" to a portable form such as
cd "$(dirname "$0")" && pwd -P (then append the relative path) so the BASE_CFG
variable is constructed reliably across platforms.
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Yeah, so in the end state, not for this PR, We should just do something like this: the recipe part is everything that do not need to change for the user to reproduce the optimization pipeline, and the launcher-config is everything the user need to modify accordingly to their system setup. |
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
What does this PR do?
Refactors EAGLE3 training to use a single base YAML config with OmegaConf dotlist overrides.
Type of change: Refactor
Changes
modelopt_recipes/speculative_decoding/_base_eagle3.yamlfor all EAGLE3 training; removed per-model child YAMLs.launch_train.shaccepts--config <yaml>plus dotlist overrides (e.g.model.model_name_or_path=xxx).__base__YAML inheritance logic frommain.py.dp_shard_sizedefault changed from0sentinel toNonefor clarity.eagle_config.jsonandfsdp_config.json; architecture config is now nested undereagle.eagle_architecture_configin YAML.train_eagle3_and_export.shnow uses base YAML + dotlist instead of generating a temporary YAML.Usage